import torch
import numpy as np
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt

def plot_decision_bundry(model, x, y, h=0.01):
    """Draw the decision regions of ``model`` over a 2-D grid, then the data points.

    Args:
        model: callable that receives an (n, 2) NumPy array of grid points
            (column 0 = x1, column 1 = x2) and returns n predictions as an
            object with a ``.reshape`` method (e.g. a NumPy array).
        x: array of input points; only columns 0 and 1 are used.
        y: labels for ``x``; flattened and used as scatter colours.
        h: grid step. Smaller values give a smoother boundary but the number
            of model evaluations grows roughly like 1 / h**2. Defaults to
            0.01, the value that used to be hard-coded.

    Raises:
        ValueError: if ``h`` is not positive (np.arange cannot build the grid).
    """
    if h <= 0:
        raise ValueError('grid step h must be positive, got {}'.format(h))
    # Pad the plotting window by one unit on each side of the data.
    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # Evaluate the model on every grid point in one call, then fold the
    # flat predictions back into the grid's shape for contourf.
    Z = model(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
    plt.ylabel('x2')
    plt.xlabel('x1')
    plt.scatter(x[:, 0], x[:, 1], c=y.reshape(-1), s=40, cmap=plt.cm.Spectral)

np.random.seed(1)
m = 400          # total number of points
N = m // 2       # number of points per class
D = 2            # dimensionality of each point
x = np.zeros((m, D))
y = np.zeros((m, 1), dtype='uint8')  # label vector: 0 = red, 1 = blue
a = 4            # amplitude of the radius curve

# Two classes, each sweeping theta over a span of 3.12 (roughly pi) radians.
# The radius follows the rose curve a*sin(4*theta); theta and radius both get
# Gaussian jitter with standard deviation 0.2.
for label in range(2):
    rows = slice(N * label, N * (label + 1))
    theta = np.linspace(label * 3.12, (label + 1) * 3.12, N) + np.random.randn(N) * 0.2
    radius = a * np.sin(4 * theta) + np.random.randn(N) * 0.2
    x[rows] = np.column_stack((radius * np.sin(theta), radius * np.cos(theta)))
    y[rows] = label

# Uncomment to preview the raw dataset:
# plt.scatter(x[:, 0], x[:, 1], c=y.reshape(-1), s=40, cmap=plt.cm.Spectral)
# plt.show()

# Hand the data to PyTorch as float32 tensors (labels too, for BCE losses).
x = torch.from_numpy(x).float()
y = torch.from_numpy(y).float()

# w = nn.Parameter(torch.randn(2,1))
# b = nn.Parameter(torch.zeros(1))
# #自动更新参数
# optimizer = torch.optim.SGD([w,b],1e-1)

# def logistic_regression(x):
#     return torch.mm(x,w)+b

# criterion = nn.BCEWithLogitsLoss()
# for e in range(100):
#     out = logistic_regression(Variable(x))
#     loss = criterion(out,Variable(y))
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#     if(e+1)%20 == 0:
#         print('epoch: {}, loss: {}'.format(e+1, loss))

# def logistically_plot(x):
#     x = Variable(torch.from_numpy(x).float())
#     out = F.sigmoid(logistic_regression(x))
#     out = (out>0.5)*1
#     return out.data.numpy()

# plot_decision_bundry(lambda x:logistically_plot(x),x.numpy(),y.numpy())
# plt.show()

# w1 = nn.Parameter(torch.randn(2,4))
# b1 = nn.Parameter(torch.zeros(4))

# w2 = nn.Parameter(torch.randn(4,1))
# b2 = nn.Parameter(torch.zeros(1))

# def network(x):
#     x1 = torch.mm(x,w1)+b1
#     x2 = F.tanh(x1)
#     x3 = torch.mm(x2,w2)+b2
#     return x3

# optimazer = torch.optim.SGD([w1,w2,b1,b2],1e-2)
# cratiration = nn.BCEWithLogitsLoss()

# for e in range(10000):
#     out = network(Variable(x))
#     loss = cratiration(out,Variable(y))
#     cratiration.zero_grad()
#     loss.backward()
#     optimazer.step()
#     if (e+1)%1000==0:
#         print('epcho:{},loss:{}'.format(e+1,loss))

# def plot_network(x):
#     x = Variable(torch.from_numpy(x).float())
#     x1 = torch.mm(x,w1)+b1
#     x1 = F.tanh(x1)
#     x2 = torch.mm(x1,w2)+b2
#     out = F.sigmoid(x2)
#     out = (out>0.5)*1
#     return out.data.numpy()

# plot_decision_bundry(lambda x:plot_network(x),x.numpy(),y.numpy())
# plt.show()

# set_qnet = nn.Sequential(
#     nn.Linear(2,4),
#     nn.Tanh(),
#     nn.Linear(4,1)
# )

# para = set_qnet.parameters()

# optim = torch.optim.SGD(para,0.1)
# cratiration = nn.BCEWithLogitsLoss()

# for e in range(30000):
#     out = set_qnet(Variable(x))
#     loss = cratiration(out,Variable(y))
#     cratiration.zero_grad()
#     loss.backward()
#     optim.step()
#     if (e+1)%1000==0:
#         print('epcho:{},loss:{}'.format(e+1,loss))
    
class neunet(nn.Module):
    """Two-layer perceptron: Linear -> Tanh -> Linear.

    No activation is applied to the final layer, so ``forward`` returns raw
    scores; this script pairs it with ``nn.BCEWithLogitsLoss``.

    Args:
        num_input: number of input features.
        num_output: number of output units.
        num_hidden: width of the hidden layer.
    """

    def __init__(self, num_input, num_output, num_hidden):
        super().__init__()
        # Keep the layer1/layer2/layer3 attribute names so that state_dict
        # keys stay the same. Creation order is unchanged, so the random
        # initialisation also matches.
        self.layer1 = nn.Linear(num_input, num_hidden)
        self.layer2 = nn.Tanh()
        self.layer3 = nn.Linear(num_hidden, num_output)

    def forward(self, x):
        hidden = self.layer2(self.layer1(x))
        return self.layer3(hidden)
    
# 2 inputs -> 6 hidden tanh units -> 1 output score.
double_net = neunet(2, 1, 6)

opt = torch.optim.SGD(double_net.parameters(), lr=1.0)
creatrean = nn.BCEWithLogitsLoss()

# Full-batch gradient descent over the whole dataset.
# Tensors are passed directly: Variable has been a no-op wrapper since
# PyTorch 0.4, so wrapping x / y added nothing.
for e in range(10000):
    out = double_net(x)
    loss = creatrean(out, y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if (e + 1) % 1000 == 0:
        # .item() logs the plain float instead of the tensor repr with grad_fn.
        print('epoch:{},loss:{}'.format(e + 1, loss.item()))

def plot_net(x, model=None):
    """Classify points with a network and return hard 0/1 labels.

    Args:
        x: NumPy array of points; converted to float32 before the forward
            pass. The network trained in this script expects 2 features per row.
        model: network to evaluate. Defaults to the module-level
            ``double_net`` (looked up at call time).

    Returns:
        Integer NumPy array with the model's output shape: 1 where
        sigmoid(output) > 0.5, else 0.
    """
    if model is None:
        model = double_net
    # Inference only: no_grad skips building the autograd graph, so the
    # result converts with .numpy() directly (no Variable / .data detour).
    with torch.no_grad():
        prob = torch.sigmoid(model(torch.from_numpy(x).float()))
    return (prob.numpy() > 0.5) * 1

# Draw the learned decision regions over the training points and display them.
# plot_net already has the one-argument signature expected, so no lambda is needed.
plot_decision_bundry(plot_net, x.numpy(), y.numpy())
plt.show()